import numpy as np
import pandas as pd
import feather, os, joblib
import sys, pickle, subprocess
from tqdm import tqdm
from joblib import Parallel, delayed
from collections import Counter
from itertools import combinations

# Enumerate every unordered pair of user ids; the score array loaded below is
# matched to these pairs by position.
num_users = 94
users = range(num_users)
user_pairs = list(combinations(users, 2))

# The trial identifier is stored as a pickled object in trial.pickle.
with open('trial.pickle', 'rb') as trial_file:
    trial = pickle.load(trial_file)

# The test data lives in a directory whose name is derived from the trial.
data_directory = f'train_test_{trial}'
test_dataset = np.load(f'{data_directory}//test_dataset.npy')

# One score per user pair, looked up later through the key "<first>_<second>".
scores = np.load(f'model_scores_{trial}.npy')
model_scores = dict(zip((f'{first}_{second}' for first, second in user_pairs), scores))

# Command line: <path of the .npy file listing test users> <integer offset
# used for this process's progress-bar position>.
fname, pos = sys.argv[1], int(sys.argv[2])
test_users = np.load(fname)

def load_model(model_file):
    """Deserialize the object stored at ``model_file`` with ``joblib.load`` and return it."""
    model = joblib.load(model_file)
    return model

# Every entry of the trial's model directory is treated as a serialized model.
model_files = os.listdir(f'models_{trial}')

num_cores = 4  # number of joblib jobs used for loading

# Feed the file names through a progress bar and load them with joblib.Parallel;
# the results come back in the same order as model_files.
parallel_loader = Parallel(n_jobs=num_cores)
load_progress = tqdm(model_files, ncols=100, ascii=' |', leave=False,
                     position=1, desc='Load models', unit='models')
loaded_models = parallel_loader(delayed(load_model)(f'models_{trial}/{name}')
                                for name in load_progress)

# Key each loaded model by its file name with the extension stripped.
models = dict(zip((os.path.splitext(os.path.basename(name))[0] for name in model_files),
                  loaded_models))

def process_chunk(X_test, models):
    """Tally the pairwise models' votes for one block of test samples.

    Each model whose entry in the module-level ``model_scores`` is above 0.6
    predicts a label for every sample in ``X_test``.  Labels are compared
    against 0 and 1; whichever is more frequent selects the first or second
    user id encoded in the model's ``"<a>_<b>"`` key, and that user receives
    one vote.  A model whose 0 and 1 counts are equal casts no vote.

    Args:
        X_test: Samples to classify, passed unchanged to each model's
            ``predict``; ``len(X_test)`` is taken as the sample count.
        models: Mapping from ``"<a>_<b>"`` keys to objects exposing
            ``predict``.

    Returns:
        A one-row ``pandas.DataFrame`` whose columns are the user ids that
        received votes (ascending) and whose values are the vote counts.  The
        frame has a row but no columns when nothing was voted for.

    Raises:
        KeyError: If a key of ``models`` is missing from ``model_scores``.
    """
    votes = Counter()
    # With no samples there is nothing to vote on; skip the models entirely
    # instead of handing them an empty array, which estimators may reject.
    if len(X_test) > 0:
        pbar_models = tqdm(models.items(), desc='Prediction', ncols=100, ascii=' |',
                           position=1, leave=False, unit='models')
        for user_pair, model in pbar_models:
            # Written as "not >" so a NaN score is skipped, as a plain
            # "> 0.6" test would skip it.
            if not model_scores[user_pair] > 0.6:
                continue
            pair = [int(x) for x in user_pair.split('_')]
            y_predict = model.predict(X_test)
            zeros = (y_predict == 0).sum()
            ones = (y_predict == 1).sum()
            if zeros == ones:
                continue  # exact tie: this model abstains
            # The two counts already decide the majority label, so no extra
            # pass over the labels (and no integer-dtype requirement) is needed.
            votes[pair[int(ones > zeros)]] += 1
    predictions = dict(sorted(votes.items()))
    # Use the DataFrame constructor rather than from_records: from_records
    # derives the row count from the column arrays, so an empty tally could
    # give a frame with no rows and the caller would lose this chunk's row.
    return pd.DataFrame([predictions])

# Start from an empty frame that already has one column per user id, so the
# concatenated result keeps these columns even for users no chunk mentions.
result = pd.DataFrame(columns=[user for user in users])
pbar_test_users = tqdm(test_users, desc='Test user', ascii=' |', leave=False,
                       ncols=100, position=pos+2, unit='user')
for test_user in pbar_test_users:
    X_test = test_dataset[test_user]
    # Drop every sample (row) that contains at least one NaN.
    X_test = X_test[~np.isnan(X_test).any(axis=1)]
    result_chunk = process_chunk(X_test, models)
    # Chunks are appended in test_users order and the index is renumbered.
    result = pd.concat([result, result_chunk]).reset_index(drop=True)

# Make sure the output directory exists so the finished predictions are not
# lost to a missing-directory error at the very last step.
os.makedirs("uis_partition_files", exist_ok=True)
# The process id is part of the file name -- presumably so that several
# instances of this script can write their partitions side by side.
feather.write_dataframe(result, f"uis_partition_files//result_{os.getpid()}.feather")